import json

from pathlib import Path


def _select_best_models(score_dataset, clinical_departments, display_to_key):
    """Map each department to the file-name key of its highest-scoring model.

    Prints the department name and the winning ``(model, weighted_average)``
    pair for every department, in the order given.

    Raises:
        ValueError: if ``score_dataset`` is empty (no model can be chosen).
        KeyError: if an entry lacks the department, or the winner's display
            name is missing from ``display_to_key``.
    """
    best_models = {}
    for clinical_department in clinical_departments:
        print(clinical_department + ':')
        if not score_dataset:
            raise ValueError('score data is empty; cannot pick a best model')
        # max() returns the first maximal item it meets; scanning in reverse
        # keeps the previous tie-break of sorted(...)[-1] (last entry wins).
        top_item = max(
            reversed(score_dataset),
            key=lambda item: item[clinical_department]['weighted_average'],
        )
        top_score = (top_item['model'], top_item[clinical_department]['weighted_average'])
        print(top_score)
        best_models[clinical_department] = display_to_key[top_score[0]]
    return best_models


def _collect_clinical_cases(department_to_model, inference_root, lang):
    """Collect, in department order, each department's cases from its chosen model's file.

    Each inference file ``inference_{lang}_{model}.json`` is read at most once,
    even when the same model wins several departments.
    """
    loaded_inferences = {}  # model key -> parsed inference file
    clinical_cases_list = []
    for clinical_department, model_name in department_to_model.items():
        print(clinical_department)
        print(model_name)
        if model_name not in loaded_inferences:
            inference_load_path = inference_root / Path(f'inference_{lang}_{model_name}.json')
            with open(inference_load_path, mode='r', encoding='utf-8') as file:
                loaded_inferences[model_name] = json.load(file)
        clinical_cases_list.extend(
            item for item in loaded_inferences[model_name]
            if item['clinical_department'] == clinical_department
        )
    return clinical_cases_list


def main():
    """Merge each department's cases from its best-scoring model into one JSON file.

    For every department in ``clinical_department_zh_list``, pick the model whose
    ``weighted_average`` score is highest (ties: the later entry in the score
    file wins). Then take that department's cases from the chosen model's
    inference file and write all collected cases to ``merged_save_name``.

    Reads module-level globals set by the ``__main__`` block: ``score_dir``,
    ``score_load_name``, ``clinical_department_zh_list``,
    ``model_name_mapping_dict`` (display name -> file-name key), ``language``,
    ``inference_dir`` and ``merged_save_name``.

    Raises:
        ValueError: if the score file lists no models while departments are requested.
        KeyError: if a department/model is missing from the score data or mapping.
        FileNotFoundError: if the score file or a required inference file is absent.
    """
    score_load_path = score_dir / Path(score_load_name)
    with open(score_load_path, mode='r', encoding='utf-8') as file:
        score_dataset = json.load(file)['data']

    clinical_department_to_model_name_dict = _select_best_models(
        score_dataset, clinical_department_zh_list, model_name_mapping_dict
    )
    print(clinical_department_to_model_name_dict)

    clinical_cases_list = _collect_clinical_cases(
        clinical_department_to_model_name_dict, inference_dir, language
    )

    print(len(clinical_cases_list))
    merged_save_path = inference_dir / Path(merged_save_name)
    with open(merged_save_path, mode='w', encoding='utf-8') as file:
        json.dump(clinical_cases_list, file, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    # Run configuration: evaluation language, score snapshot to read, merged output name.
    language = 'zh'
    score_load_name = 'score_weighted_average(2024-05-27).json'
    merged_save_name = f'inference_{language}_agent@1@1.json'

    # Short file-name key -> display name of each evaluated model.
    model_name_mapping_dict = {
        'baichuan2chat': 'Baichuan2-13B-Chat',  # https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
        'bianque2': 'BianQue-2',  # https://huggingface.co/scutcyr/BianQue-2
        'bluelmchat': 'BlueLM-7B-Chat',  # https://huggingface.co/vivo-ai/BlueLM-7B-Chat
        'chatglm3': 'ChatGLM3-6B',  # https://huggingface.co/THUDM/chatglm3-6b
        'claude3': 'Claude-3',  # https://www.anthropic.com/news/claude-3-haiku
        'discmedllm': 'DISC-MedLLM',  # https://huggingface.co/Flmc/DISC-MedLLM
        'geminipro': 'Gemini-Pro',  # https://ai.google.dev/models/gemini
        'gpt3.5': 'GPT-3.5',  # https://platform.openai.com/docs/models/gpt-3-5
        'gpt4': 'GPT-4',  # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
        'huatuogpt2': 'HuatuoGPT2-34B',  # https://huggingface.co/FreedomIntelligence/HuatuoGPT2-34B
        'internlm2chat': 'InternLM2-20B-Chat',  # https://huggingface.co/internlm/internlm2-chat-20b
        'pulse': 'PULSE-20B',  # https://huggingface.co/OpenMEDLab/PULSE-20bv5
        'qwenchat': 'Qwen-72B-Chat',  # https://huggingface.co/Qwen/Qwen-72B-Chat
        'spark3': 'Spark-3',  # https://xinghuo.xfyun.cn/
        'taiyillm': 'Taiyi-LLM',  # https://huggingface.co/DUTIR-BioNLP/Taiyi-LLM
        'wingpt2': 'WiNGPT2-14B-Chat',  # https://huggingface.co/winninghealth/WiNGPT2-14B-Chat
        'yichat': 'Yi-34B-Chat',  # https://huggingface.co/01-ai/Yi-34B-Chat
    }
    # main() looks models up by the score file's 'model' value and needs the
    # short key to build the inference file name, so flip the mapping to
    # display name -> short key (the global keeps its name because main() reads it).
    model_name_mapping_dict = dict(zip(model_name_mapping_dict.values(), model_name_mapping_dict.keys()))

    # Chinese department name (the key used in the score and inference data) -> English name.
    clinical_department_zh_to_en_dict = {
        '乳腺外科': 'breast surgical department',
        '产科': 'obstetrics department',
        '儿科': 'pediatrics department',
        '内分泌内科': 'endocrinology department',
        '呼吸内科': 'respiratory medicine department',
        '妇科': 'gynecology department',
        '心脏外科': 'cardiac surgical department',
        '心血管内科': 'cardiovascular medicine department',
        '泌尿外科': 'urinary surgical department',
        '消化内科': 'gastroenterology department',
        '甲状腺外科': 'thyroid surgical department',
        '疝外科': 'hernia surgical department',
        '神经内科': 'neurology department',
        '神经外科': 'neurosurgery department',
        '耳鼻咽喉头颈外科': 'otolaryngology head and neck surgical department',
        '肛门结直肠外科': 'anus and intestine surgical department',
        '肝胆胰外科': 'hepatobiliary and pancreas surgical department',
        '肾内科': 'nephrology department',
        '胃肠外科': 'gastrointestinal surgical department',
        '胸外科': 'thoracic surgical department',
        '血液内科': 'hematology department',
        '血管外科': 'vascular surgical department',
        '骨科': 'orthopedics department',
    }
    clinical_department_zh_list = [*clinical_department_zh_to_en_dict]
    clinical_department_en_list = [*clinical_department_zh_to_en_dict.values()]

    # Data folders sit in the parent of this script's directory. exist_ok=True
    # already tolerates an existing directory (and still raises if a non-directory
    # occupies the path), so no separate is_dir() check is needed.
    inference_dir = Path(__file__).parent.parent / 'inferences'
    inference_dir.mkdir(parents=True, exist_ok=True)
    score_dir = Path(__file__).parent.parent / 'scores'
    score_dir.mkdir(parents=True, exist_ok=True)

    main()
